{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7cc6013c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第12章/加载编码工具\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0bc403a9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 9600\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 1200\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第12章/加载数据集\n",
    "from datasets import load_from_disk\n",
    "\n",
    "dataset = load_from_disk('./data/ChnSentiCorp')\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "640c99c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cpu'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第12章/定义计算设备\n",
    "import torch\n",
    "\n",
    "device = 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "ce069966",
   "metadata": {},
   "outputs": [],
   "source": [
    "#第12章/数据整理函数\n",
    "def collate_fn(data):\n",
    "    sents = [i['text'] for i in data]\n",
    "    labels = [i['label'] for i in data]\n",
    "\n",
    "    #编码\n",
    "    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,\n",
    "                                   truncation=True,\n",
    "                                   padding=True,\n",
    "                                   max_length=512,\n",
    "                                   return_tensors='pt')\n",
    "\n",
    "    #转移到计算设备\n",
    "    for k, v in data.items():\n",
    "        data[k] = v.to(device)\n",
    "\n",
    "    data['labels'] = torch.LongTensor(labels).to(device)\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "1f131014",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "600"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第12章/数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset['train'],\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "05c9e025",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids torch.Size([16, 235])\n",
      "token_type_ids torch.Size([16, 235])\n",
      "attention_mask torch.Size([16, 235])\n",
      "labels torch.Size([16])\n"
     ]
    }
   ],
   "source": [
    "#第12章/查看数据样例\n",
    "for i, data in enumerate(loader):\n",
    "    break\n",
    "\n",
    "for k, v in data.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "b60fa9a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10226.9186\n"
     ]
    }
   ],
   "source": [
    "#第12章/加载预训练模型\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "#加载模型\n",
    "model = AutoModelForSequenceClassification.from_pretrained('bert-base-chinese',\n",
    "                                                           num_labels=2)\n",
    "#设定计算设备\n",
    "model.to(device)\n",
    "\n",
    "#统计参数量\n",
    "print(sum(i.numel() for i in model.parameters()) / 10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b2c4220b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(21128, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (1): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (2): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (3): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (4): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (5): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (6): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (7): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (8): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (9): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (10): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (11): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "788cb5e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.7723, grad_fn=<NllLossBackward0>), torch.Size([16, 2]))"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#模型试算\n",
    "out = model(**data)\n",
    "\n",
    "out['loss'], out['logits'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6a36158a",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/gpu/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.688614547252655 0.0004991666666666666 0.625\n",
      "10 0.6768248677253723 0.0004908333333333334 0.625\n",
      "20 0.7436589002609253 0.0004825 0.375\n",
      "30 0.5400702357292175 0.0004741666666666667 0.9375\n",
      "40 0.6189451217651367 0.00046583333333333334 0.6875\n",
      "50 0.5567441582679749 0.0004575 0.9375\n",
      "60 0.5188874006271362 0.00044916666666666667 0.875\n",
      "70 0.5120527148246765 0.0004408333333333334 0.875\n",
      "80 0.5961946249008179 0.0004325 0.6875\n",
      "90 0.5220378637313843 0.0004241666666666667 0.8125\n",
      "100 0.5031461715698242 0.0004158333333333333 0.875\n",
      "110 0.5246412754058838 0.0004075 0.75\n",
      "120 0.5258313417434692 0.0003991666666666667 0.875\n",
      "130 0.5714597105979919 0.0003908333333333333 0.75\n",
      "140 0.5225942134857178 0.00038250000000000003 0.8125\n",
      "150 0.4699893891811371 0.00037416666666666664 0.9375\n",
      "160 0.470580130815506 0.00036583333333333335 0.9375\n",
      "170 0.44881629943847656 0.0003575 0.9375\n",
      "180 0.42978283762931824 0.0003491666666666667 0.9375\n",
      "190 0.5476064682006836 0.00034083333333333334 0.875\n",
      "200 0.441872775554657 0.0003325 0.9375\n",
      "210 0.44583767652511597 0.00032416666666666667 0.875\n",
      "220 0.49078699946403503 0.0003158333333333334 0.8125\n",
      "230 0.47897177934646606 0.0003075 0.8125\n",
      "240 0.5153447389602661 0.0002991666666666667 0.8125\n",
      "250 0.4353978633880615 0.0002908333333333333 0.875\n",
      "260 0.47860240936279297 0.0002825 0.8125\n",
      "270 0.644160807132721 0.0002741666666666667 0.75\n",
      "280 0.5591180324554443 0.0002658333333333333 0.75\n",
      "290 0.37128186225891113 0.0002575 1.0\n",
      "300 0.5672366619110107 0.0002491666666666667 0.8125\n",
      "310 0.4191279113292694 0.00024083333333333335 0.9375\n",
      "320 0.6039121747016907 0.0002325 0.625\n",
      "330 0.5598694682121277 0.00022416666666666665 0.75\n",
      "340 0.4366607367992401 0.00021583333333333334 0.875\n",
      "350 0.4092235267162323 0.0002075 0.9375\n",
      "360 0.49602237343788147 0.00019916666666666667 0.75\n",
      "370 0.41506507992744446 0.00019083333333333333 0.875\n",
      "380 0.47769472002983093 0.0001825 0.8125\n",
      "390 0.4564906656742096 0.00017416666666666668 0.9375\n",
      "400 0.47447848320007324 0.00016583333333333334 0.9375\n",
      "410 0.5312456488609314 0.0001575 0.75\n",
      "420 0.41240501403808594 0.00014916666666666667 0.9375\n",
      "430 0.4089215397834778 0.00014083333333333333 0.9375\n",
      "440 0.5233736634254456 0.00013250000000000002 0.75\n",
      "450 0.4825798571109772 0.00012416666666666666 0.875\n",
      "460 0.5071223974227905 0.00011583333333333333 0.8125\n",
      "470 0.4656216502189636 0.0001075 0.875\n",
      "480 0.4673466086387634 9.916666666666667e-05 0.875\n",
      "490 0.5248121619224548 9.083333333333334e-05 0.8125\n",
      "500 0.4834817051887512 8.25e-05 0.8125\n",
      "510 0.46574878692626953 7.416666666666668e-05 0.875\n",
      "520 0.5053324103355408 6.583333333333333e-05 0.875\n",
      "530 0.4442726969718933 5.75e-05 0.875\n",
      "540 0.4078090190887451 4.9166666666666665e-05 0.875\n",
      "550 0.38644489645957947 4.0833333333333334e-05 1.0\n",
      "560 0.3533816337585449 3.2500000000000004e-05 1.0\n",
      "570 0.40406540036201477 2.4166666666666667e-05 0.9375\n",
      "580 0.40326061844825745 1.5833333333333336e-05 1.0\n",
      "590 0.3704525828361511 7.5e-06 0.9375\n"
     ]
    }
   ],
   "source": [
    "#第12章/训练\n",
    "from transformers import AdamW\n",
    "from transformers.optimization import get_scheduler\n",
    "\n",
    "\n",
    "def train():\n",
    "    #定义优化器\n",
    "    optimizer = AdamW(model.parameters(), lr=5e-4)\n",
    "\n",
    "    #定义学习率调节器\n",
    "    scheduler = get_scheduler(name='linear',\n",
    "                              num_warmup_steps=0,\n",
    "                              num_training_steps=len(loader),\n",
    "                              optimizer=optimizer)\n",
    "\n",
    "    #模型切换到训练模式\n",
    "    model.train()\n",
    "\n",
    "    #按批次遍历训练集中的数据\n",
    "    for i, data in enumerate(loader):\n",
    "\n",
    "        #模型计算\n",
    "        out = model(**data)\n",
    "\n",
    "        #计算loss并使用梯度下降法优化模型参数\n",
    "        out['loss'].backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        model.zero_grad()\n",
    "\n",
    "        #输出各项数据的情况，便于观察\n",
    "        if i % 10 == 0:\n",
    "            out = out['logits'].argmax(dim=1)\n",
    "            accuracy = (out == labels).sum().item() / len(labels)\n",
    "            lr = optimizer.state_dict()['param_groups'][0]['lr']\n",
    "            print(i, loss.item(), lr, accuracy)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b8758dc9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.89375\n"
     ]
    }
   ],
   "source": [
    "#第12章/测试\n",
    "def test():\n",
    "    #定义测试数据集加载器\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    #下游任务模型切换到运行模式\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    #按批次遍历测试集中的数据\n",
    "    for i, (data) in enumerate(loader_test):\n",
    "\n",
    "        #计算5个批次即可，不需要全部遍历\n",
    "        if i == 5:\n",
    "            break\n",
    "\n",
    "        print(i)\n",
    "\n",
    "        #计算\n",
    "        with torch.no_grad():\n",
    "            out = model(**data)\n",
    "\n",
    "        #统计正确率\n",
    "        out = out['logits'].argmax(dim=1)\n",
    "        correct += (out == labels).sum().item()\n",
    "        total += len(labels)\n",
    "\n",
    "    print(correct / total)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
